import math

import cv2
import numpy as np
import pandas as pd
import scipy
import scipy.ndimage
from PIL import Image, ImageEnhance
from scipy.spatial import distance
from tqdm import tqdm


# Side length, in pixels, of the square image tiles handled in this module
IMAGE_DIM = 224
# Number of mask pixels attributed to a single kiln when estimating how many
# kilns a connected pixel cluster contains
KILN_SIZE = 3500

# Zoom level of imagery
ZOOM_LEVEL = 17
# Change in latitude / longitude per pixel: the coordinate difference between
# image tiles divided by the tile size in pixels
LAT_INCREMENT = (24.0840818 - 24.08658926) / IMAGE_DIM
LNG_INCREMENT = (90.34057617 - 90.33782959) / IMAGE_DIM


### Mask images module ###
def _keep_above_nonzero_mean(activation):
    '''Subtract the mean of the nonzero entries and clip negatives to zero.

    A map without any nonzero entry is returned unchanged: the mean of an
    empty selection is NaN and would otherwise poison the whole map.
    '''
    nonzero_values = activation[np.nonzero(activation)]
    if nonzero_values.size == 0:
        return activation
    return np.clip(activation - np.mean(nonzero_values), 0, None)


def mask_images(results, results_cam, test_files, kiln_model, model_type='vgg16_cam'):
    '''Build a class-activation-map (CAM) mask for every image.

    Args:
        results: sequence of per-image prediction dicts; the class score
            vector is read from ``result['dense_1'][0]``.
        results_cam: sequence aligned with ``results``; the last convolutional
            feature map is read from ``results_cam[i]['model_1']`` and must be
            height x width x channels once singleton axes are squeezed out.
        test_files: image paths aligned with ``results``; used as dict keys.
        kiln_model: object exposing ``get_variable_value('dense_1/kernel')``,
            returning a (channels, classes) weight matrix
            (presumably a TensorFlow Estimator -- confirm against the caller).
        model_type: unused; kept for backward compatibility.

    Returns:
        dict mapping each image path to its uint8 mask (values 0-255) of
        IMAGE_DIM x IMAGE_DIM pixels.
    '''
    # The dense-layer weights do not change between images, so fetch them once
    # instead of once per image.
    all_amp_layer_weights = kiln_model.get_variable_value('dense_1/kernel')
    opening_kernel = np.ones((15, 15), np.uint8)

    img_path_to_mask = {}
    for i, result in tqdm(enumerate(results), total=len(results)):
        img_path = test_files[i]

        pred_vec = result['dense_1'][0]
        # Drop singleton axes so the feature map is height x width x channels
        last_conv_output = np.squeeze(results_cam[i]['model_1'])
        height, width, channels = last_conv_output.shape

        # Index of the highest-scoring class
        pred = np.argmax(pred_vec)

        # Bilinear upsampling of every channel to the tile size. The factors
        # are derived from the feature-map shape so that any spatial size or
        # channel count works (e.g. 14 x 14 x 512 gives the factor 16).
        zoom_factors = (float(IMAGE_DIM) / height, float(IMAGE_DIM) / width, 1)
        mat_for_mult = scipy.ndimage.zoom(
            last_conv_output, zoom_factors, order=1)

        # Weights connecting each channel to the predicted class; dim: (channels,)
        amp_layer_weights = all_amp_layer_weights[:, pred]

        # Class activation map for the predicted class: weighted sum over channels
        final_output = np.dot(
            mat_for_mult.reshape(-1, channels),
            amp_layer_weights).reshape(mat_for_mult.shape[:2])

        # Twice: subtract the mean of the nonzero values and clip at zero,
        # keeping only the values above that mean
        final_output = _keep_above_nonzero_mean(final_output)
        final_output = _keep_above_nonzero_mean(final_output)

        # Repeat, this time starting with the mean over all values
        final_output = np.clip(final_output - np.mean(final_output), 0, None)
        final_output = _keep_above_nonzero_mean(final_output)

        # Normalize in place to make values range 0 to 255
        cv2.normalize(final_output, final_output, 0, 255, cv2.NORM_MINMAX)
        img_arr = np.array(final_output, dtype=np.uint8)

        # Perform opening (erosion followed by dilation) to remove noise
        img_path_to_mask[img_path] = cv2.morphologyEx(
            img_arr, cv2.MORPH_OPEN, opening_kernel)

    return img_path_to_mask


def create_connected_component_image(cc, yes_df, all_df, img_path_to_mask):
    '''Stitch the tiles of one connected component into two square images.

    Args:
        cc: iterable of vertices exposing ``row`` and ``col`` attributes.
        yes_df: DataFrame with 'row', 'col', 'prediction' and 'path' columns.
        all_df: DataFrame with 'row', 'col', 'lat' and 'lon' columns.
        img_path_to_mask: dict mapping a tile path to its mask array.

    Returns:
        (image of the original tiles, image of the tile masks, lat, lon),
        where lat/lon are read from the first ``all_df`` entries matching the
        smallest row and the smallest column of the component.
    '''
    row_ids = [vertex.row for vertex in cc]
    col_ids = [vertex.col for vertex in cc]
    top, bottom = min(row_ids), max(row_ids)
    left, right = min(col_ids), max(col_ids)

    # Both canvases are square, sized by the longer side of the bounding box
    n_cols = right - left + 1
    n_rows = bottom - top + 1
    side = IMAGE_DIM * max(n_cols, n_rows)

    # Lat/Long is that of top left image
    lat = all_df.loc[(all_df['row'] == top)].iloc[0]['lat']
    lon = all_df.loc[(all_df['col'] == left)].iloc[0]['lon']

    # Restrict the lookup table to the tiles inside the bounding box
    inside_rows = (yes_df['row'] >= top) & (yes_df['row'] <= bottom)
    inside_cols = (yes_df['col'] >= left) & (yes_df['col'] <= right)
    candidates = yes_df.loc[inside_rows & inside_cols]

    image_of_full_area = Image.new('RGB', (side, side), color=0)
    image_of_full_area_mask = Image.new('RGB', (side, side), color=0)

    for r in range(top, bottom + 1):
        for c in range(left, right + 1):
            # Tile sitting at this row, col of the overall grid (if any)
            tile = candidates.loc[
                (candidates['row'] == r) & (candidates['col'] == c)]
            if tile.empty:
                continue
            if tile['prediction'].values[0] != 'yeskiln':
                continue

            # Top-left corner, in pixels, where this tile is pasted
            paste_at = ((c - left) * IMAGE_DIM, (r - top) * IMAGE_DIM)

            # Canvas of the original imagery
            tile_img = Image.open(tile['path'].values[0]).resize(
                (IMAGE_DIM, IMAGE_DIM))
            image_of_full_area.paste(tile_img, box=paste_at)

            # Canvas of the CAM output
            tile_mask = Image.fromarray(
                img_path_to_mask[tile['path'].values[0]])
            image_of_full_area_mask.paste(tile_mask, box=paste_at)
            tile_img.close()

    return image_of_full_area, image_of_full_area_mask, lat, lon


def postprocess_connected_component_mask(image_of_full_area_mask):
    '''Clean up a stitched mask image and binarize it.

    The mask is converted to grayscale, run through three filter passes
    (median filter, contrast boost and, for the first two passes, Gaussian
    smoothing), eroded and finally thresholded with Otsu's method.

    Args:
        image_of_full_area_mask: three-channel mask image (PIL image or
            array-like convertible to a uint8 array).

    Returns:
        PIL image holding the thresholded single-channel mask.
    '''
    pixels = np.array(image_of_full_area_mask, dtype=np.uint8)
    # NOTE(review): the conversion code assumes BGR channel order; this makes
    # no difference as long as all three channels carry the same values.
    current = np.array(
        cv2.cvtColor(pixels, cv2.COLOR_BGR2GRAY), dtype=np.uint8)

    # (contrast factor, whether Gaussian smoothing follows) for each pass
    filter_passes = ((1.5, True), (1.5, True), (2.0, False))
    for contrast_factor, smooth_afterwards in filter_passes:
        # Median filter to amplify/concentrate signals
        current = scipy.ndimage.median_filter(current, size=10)

        # Increase contrast to amplify signals and bridge gaps
        current = Image.fromarray(current)
        current = ImageEnhance.Contrast(current).enhance(contrast_factor)

        if smooth_afterwards:
            # Smooth using Gaussian filter
            current = scipy.ndimage.gaussian_filter(current, sigma=10)

    # Erode, then binarize with an automatically chosen (Otsu) threshold
    erosion_kernel = np.ones((5, 5), np.uint8)
    eroded = cv2.erode(
        np.array(current, dtype=np.uint8), erosion_kernel, iterations=1)
    _, binary = cv2.threshold(eroded, 0, 255, cv2.THRESH_OTSU)

    return Image.fromarray(binary)


class Vertex:
    '''Graph node describing one image tile by grid position and id.'''

    def __init__(self, row, col, new_id):
        # Position of the tile in the image grid
        self.row, self.col = row, col
        # Identifier of the tile
        self.new_id = new_id
        # Traversal flag; a fresh vertex has not been visited yet
        self.visited = False

    def mark_visited(self):
        '''Flag this vertex as visited.'''
        self.visited = True


class Graph:
    '''Undirected graph of image tiles, de-duplicated by ``new_id``.

    Vertices are compared by identity, so the first vertex object registered
    for a given ``new_id`` becomes the canonical one; later objects carrying
    the same ``new_id`` are mapped onto it.
    '''

    def __init__(self, num_vertices, df):  # num_vertices == num_yeskiln
        self.num_vertices = num_vertices
        self.df = df  # Pandas DataFrame containing info about each yeskiln image
        self.vertices = {}  # maps new_id to vertex
        self.adj = {}  # maps vertex to adjacent vertices

    def _canonical(self, vertex):
        '''Return the stored vertex for ``vertex.new_id``, registering it if new.'''
        if vertex.new_id not in self.vertices:
            self.vertices[vertex.new_id] = vertex
            self.adj[vertex] = set()
        return self.vertices[vertex.new_id]

    def add_edge(self, v1, v2):
        '''Register ``v1`` and, when ``v2`` is given, connect the two vertices.

        Passing a falsy ``v2`` (e.g. ``None``) only registers ``v1``, which
        allows adding isolated vertices.
        '''
        first = self._canonical(v1)
        if v2:
            second = self._canonical(v2)
            self.adj[first].add(second)
            self.adj[second].add(first)

    def find_connected_components(self):
        '''Return a list with one set of vertices per connected component.

        Only vertices not yet marked as visited are explored and the visited
        flags are never reset, so a second call returns an empty list.
        '''
        connected_components = []
        for v in self.adj:
            if not v.visited:
                connected_components.append(self.DFS_helper(v, set()))
        return connected_components

    def DFS_helper(self, v, connected):
        '''Add every unvisited vertex reachable from ``v`` to ``connected``.

        The traversal uses an explicit stack rather than recursion, so large
        components cannot exceed the interpreter's recursion limit.

        Returns:
            The same ``connected`` set that was passed in.
        '''
        v.mark_visited()
        connected.add(v)
        pending = [v]
        while pending:
            current = pending.pop()
            for neighbor in self.adj[current]:
                if not neighbor.visited:
                    # Mark on discovery so a vertex is never queued twice
                    neighbor.mark_visited()
                    connected.add(neighbor)
                    pending.append(neighbor)
        return connected

    def get_neighbors(self, v):
        '''Returns new Vertex objects for the tiles directly above, below,
        left and right of the given vertex (in that order) that exist in
        ``self.df``.'''
        df = self.df
        neighbors = []
        for row_step, col_step in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            match = df.loc[(df['row'] == v.row + row_step) &
                           (df['col'] == v.col + col_step)]
            if match.empty:
                continue
            # Only the first matching entry is used
            neighbors.append(Vertex(match['row'].values[0],
                                    match['col'].values[0],
                                    match['new_id'].values[0]))
        return neighbors


### Centroids to coordinates module ###
def convert_centroid_to_latlon(x, y, lat, lng):
    '''Translate a pixel offset from a reference point into (lat, lng).

    ``y`` is scaled by LAT_INCREMENT and added to ``lat``; ``x`` is scaled by
    LNG_INCREMENT and added to ``lng``.
    '''
    return (lat + y * LAT_INCREMENT, lng + x * LNG_INCREMENT)

def recenter(image_of_full_area, x, y):
    '''Crop a window around a pixel and place it on a black square tile.

    Args:
        image_of_full_area: three-channel image (PIL image or array-like).
        x: integer index along the first array axis (rows) of the window center.
        y: integer index along the second array axis (columns) of the center.

    Returns:
        PIL image of IMAGE_DIM x IMAGE_DIM pixels. The cropped window reaches
        IMAGE_DIM // 3 pixels to each side of the center (less where it hits
        the image border) and is pasted IMAGE_DIM // 6 pixels away from the
        top-left corner of an otherwise black tile.
    '''
    full_area_arr = np.array(image_of_full_area)

    # Floor division keeps the bounds integral: true division yields floats on
    # Python 3, which are rejected as slice indices.
    reach = IMAGE_DIM // 3
    # The end index is clamped to size - 1, so the last row/column of the
    # source image is never part of the window.
    x_start = max(0, x - reach)
    x_end = min(full_area_arr.shape[0] - 1, x + reach)
    y_start = max(0, y - reach)
    y_end = min(full_area_arr.shape[1] - 1, y + reach)
    window = full_area_arr[x_start:x_end, y_start:y_end, :]

    canvas = np.zeros((IMAGE_DIM, IMAGE_DIM, 3))
    margin = IMAGE_DIM // 6
    canvas[margin:margin + window.shape[0],
           margin:margin + window.shape[1], :] = window
    return Image.fromarray(canvas.astype(np.uint8))


def _kiln_target_points(cluster, num_kilns):
    '''Return the (row, col) points where kilns are expected inside a cluster.

    The cluster extent is described by two anchor pixels: the one with the
    smallest column in the top row and the one with the largest column in the
    bottom row. A single kiln is placed halfway between the anchors; several
    kilns are spread evenly along the line from the first to the second.
    '''
    rows = cluster[:, 0]
    min_point = np.min(cluster[rows == np.min(rows)], 0)
    max_point = np.max(cluster[rows == np.max(rows)], 0)

    if num_kilns <= 1:
        return [(int((min_point[0] + max_point[0]) / 2),
                 int((min_point[1] + max_point[1]) / 2))]

    # int() of a float quotient truncates toward zero, which matters because
    # the column distance between the anchors can be negative.
    step_x = int(float(max_point[0] - min_point[0]) / num_kilns)
    step_y = int(float(max_point[1] - min_point[1]) / num_kilns)

    # The first point sits half a step past the first anchor; every further
    # point is one full step beyond the previous one.
    point = (min_point[0] + int(float(step_x) / 2),
             min_point[1] + int(float(step_y) / 2))
    targets = [point]
    for _ in range(num_kilns - 1):
        point = (point[0] + step_x, point[1] + step_y)
        targets.append(point)
    return targets


def get_coordinates(image_of_full_area, full_area_mask_arr, lat, lng, shape_model, shape_input_name, shape_model_checkpoint, cc_sizes=None):
    '''Locate kilns in a binary mask and save a recentered image of each one.

    Every 8-connected pixel cluster of the mask is assumed to hold
    ``int(cluster size / KILN_SIZE)`` kilns (at least one location is always
    produced). Each expected location is snapped to the nearest cluster pixel,
    converted to latitude/longitude, and a crop of ``image_of_full_area``
    centered on it is written to
    'test_imgs/<lat>,<lon>_recentered.jpeg' (the directory is not created
    here).

    Args:
        image_of_full_area: image the recentered crops are taken from.
        full_area_mask_arr: 2-D mask array; nonzero pixels form the clusters.
        lat, lng: coordinates corresponding to pixel (0, 0) of the mask.
        shape_model, shape_input_name, shape_model_checkpoint: unused; kept
            for backward compatibility.
        cc_sizes: optional list that the pixel count of every cluster is
            appended to. A new list is created when omitted.

    Returns:
        Tuple ``(centroids, coords, cc_sizes, recentered_paths, new_ids,
        labels)``: the (row, col) pixel, the (lat, lon) pair, the saved file
        path and the '<lat>_<lon>' id of every location, plus the constant
        shape label 'fck' once per location.
    '''
    # A list default would be shared between calls and keep growing
    if cc_sizes is None:
        cc_sizes = []

    # Full 3 x 3 neighbourhood: diagonal neighbours belong to the same cluster
    structure = np.ones((3, 3), dtype=int)
    labeled, ncomponents = scipy.ndimage.label(full_area_mask_arr, structure)

    centroids = []
    coords = []
    recentered_paths = []
    new_ids = []

    for label_id in range(1, ncomponents + 1):
        # (row, col) index pairs of all pixels carrying this label
        cluster = np.argwhere(labeled == label_id)
        cc_sizes.append(len(cluster))
        num_kilns = int(float(len(cluster)) / KILN_SIZE)

        for target in _kiln_target_points(cluster, num_kilns):
            # Snap the expected location onto the closest pixel of the cluster
            closest_idx = distance.cdist([target], cluster).argmin()
            x, y = cluster[closest_idx]

            # x is the row index and y the column index, hence the swap
            _lat, _lon = convert_centroid_to_latlon(y, x, lat, lng)
            recentered = recenter(image_of_full_area, x, y)
            recentered_file_path = 'test_imgs/' + str(_lat) + ',' + str(_lon) + '_recentered.jpeg'
            recentered.save(recentered_file_path)

            centroids.append((x, y))
            coords.append((_lat, _lon))
            recentered_paths.append(recentered_file_path)
            new_ids.append(str(_lat) + '_' + str(_lon))

    labels = ['fck'] * len(new_ids)  # Add a shape label
    return centroids, coords, cc_sizes, recentered_paths, new_ids, labels
